# -*- coding: utf_8 -*-
# Module for Malware Analysis
import io
import logging
import re
from pathlib import Path
from socket import (
    gaierror,
    gethostbyname,
)
from urllib.parse import urlparse

from django.conf import settings

import IP2Location

from mobsf.MobSF.utils import (
    append_scan_status,
    is_internet_available,
    settings_enabled,
    update_local_db,
)

logger = logging.getLogger(__name__)


class MalwareDomainCheck:
    """Check domains extracted from an app against malware intel and GeoIP.

    Domains are looked up in the local ``malwaredomainlist`` and Maltrail
    signature files and then geolocated with the IP2Location database in
    ``settings.SIGNATURE_DIR``. Findings accumulate in ``self.result``,
    keyed by domain.
    """

    # Lower-cased country, region and city names; a domain whose
    # geolocation record names one of these gets ``ofac: True``.
    OFAC_LIST = frozenset({
        'cuba', 'iran', 'north korea',
        'russia', 'syria', 'balkans',
        'belarus', 'myanmar', 'congo',
        'ethiopia', 'hong kong', 'iraq',
        'lebanon', 'libya', 'sudan',
        'venezuela', 'yemen', 'zimbabwe',
        'crimea', 'donetsk', 'luhansk',
        'afghanistan', 'china', 'ivory coast',
        'cyprus', 'eritrea', 'haiti',
        'liberia', 'somalia', 'sri lanka',
        'vietnam', 'south sudan',
    })

    def __init__(self):
        self.sig_dir = Path(settings.SIGNATURE_DIR)
        self.malwaredomainlist = self.sig_dir / 'malwaredomainlist'
        self.maltrail = self.sig_dir / 'maltrail-malware-domains.txt'
        self.iplocbin = self.sig_dir / 'IP2LOCATION-LITE-DB5.IPV6.BIN'
        # domain -> finding dict, filled by the *_check methods and
        # gelocation().
        self.result = {}
        # Set by scan() from get_domains().
        self.domainlist = None
        self.IP2Loc = IP2Location.IP2Location()

    def update_maltrail_db(self):
        """Check for update in maltrail DB.

        Writes the downloaded bytes over the local Maltrail file only when
        the response has more than 100 lines. This is a sanity check
        against saving an error page or a changed format. Errors are
        logged, not raised.
        """
        try:
            mal_db = self.maltrail
            # NOTE(review): update_local_db presumably returns the fresh
            # DB bytes when an update is available and a falsy value
            # otherwise -- confirm in mobsf.MobSF.utils.
            resp = update_local_db(
                'Maltrail',
                settings.MALTRAIL_DB_URL, mal_db)
            if not resp:
                return
            # DB needs update
            # Check2: DB Syntax Changed
            lines = resp.decode('utf-8', 'ignore').splitlines()
            if len(lines) > 100:
                logger.info('Updating Maltrail Database')
                with open(mal_db, 'wb') as wfp:
                    wfp.write(resp)
            else:
                logger.warning('Unable to Update Maltrail DB')
        except Exception:
            logger.exception('[ERROR] Maltrail DB Update')

    def gelocation(self):
        """Geolocate every scanned domain and flag OFAC locations.

        Domains not already present in ``self.result`` are recorded as
        ``{'bad': 'no'}``. Each domain is resolved with gethostbyname().
        On success, the IP2Location record (as a dict) is stored under
        ``geolocation``, and ``ofac`` is True when its country, region or
        city is listed in OFAC_LIST. Unresolvable domains get
        ``geolocation: None``. Errors are logged, not raised.
        """
        try:
            self.IP2Loc.open(self.iplocbin)
            for domain in self.domainlist:
                # Tag good domains; keep findings from the malware checks.
                entry = self.result.setdefault(domain, {'bad': 'no'})
                ip = None
                try:
                    ip = gethostbyname(domain)
                except (gaierror, UnicodeError):
                    # Unresolvable, or a label the IDNA codec rejects.
                    pass
                if not ip:
                    entry['geolocation'] = None
                    continue
                rec = self.IP2Loc.get_all(ip).__dict__
                entry['geolocation'] = rec
                places = (
                    rec.get('country_long'),
                    rec.get('region'),
                    rec.get('city'),
                )
                entry['ofac'] = any(
                    place and place.lower() in self.OFAC_LIST
                    for place in places)
        except Exception:
            logger.exception('Failed to Perform Geolocation')
        finally:
            if self.IP2Loc:
                self.IP2Loc.close()

    @staticmethod
    def _entry_matches(domain, entry_netloc, entry_ip):
        """Return True if *domain* matches one malwaredomainlist entry.

        A match occurs when *domain* equals the entry's host, or is its
        ``www.`` form, or when the entry's IP string starts with *domain*.
        *entry_netloc* may be None when the entry's URL had no valid host.
        """
        if entry_netloc and (
                entry_netloc == domain
                or (domain.startswith('www.')
                    and domain[4:] == entry_netloc)):
            return True
        return entry_ip.startswith(domain)

    def malware_check(self):
        """Flag scanned domains listed in the local malwaredomainlist DB.

        Each DB line is split on ``","``. Lines with more than five fields
        supply (domain_or_url, ip, desc) from fields 1, 2 and 4. Errors
        are logged, not raised.
        """
        try:
            with io.open(self.malwaredomainlist,
                         mode='r',
                         encoding='utf8',
                         errors='ignore') as flip:
                entry_list = flip.readlines()
            for entry in entry_list:
                fields = entry.split('","')
                if len(fields) <= 5:
                    continue
                dmn_url = fields[1]
                entry_ip = fields[2]
                entry_desc = fields[4]
                # Depends only on the entry, so parse it once.
                dmn_neturl = get_netloc(dmn_url)
                for domain in self.domainlist:
                    if self._entry_matches(domain, dmn_neturl, entry_ip):
                        # A fresh dict per domain: gelocation() mutates
                        # result entries in place, so they must not be
                        # shared between domains.
                        self.result[domain] = {
                            'domain_or_url': dmn_url,
                            'ip': entry_ip,
                            'desc': entry_desc,
                            'bad': 'yes',
                        }
        except Exception:
            logger.exception('[ERROR] Performing Malware check')

    def maltrail_check(self):
        """Flag scanned domains that appear verbatim in the Maltrail list.

        Errors are logged, not raised.
        """
        try:
            with io.open(self.maltrail,
                         mode='r',
                         encoding='utf8',
                         errors='ignore') as flip:
                # Set for O(1) membership tests per domain.
                known_bad = set(flip.read().splitlines())
            for domain in self.domainlist:
                if domain in known_bad:
                    self.result[domain] = {
                        'domain_or_url': domain,
                        'ip': 'N/A',
                        'desc': 'Malicious Domain tagged by Maltrail',
                        'bad': 'yes',
                    }
        except Exception:
            logger.exception('[ERROR] Performing Maltrail Check')

    def update(self):
        """Refresh the Maltrail DB when internet access is available."""
        if is_internet_available():
            self.update_maltrail_db()
        else:
            logger.warning('Internet not available. '
                           'Skipping Malware Database Update.')

    def scan(self, checksum, urls):
        """Run all domain checks for *urls* and return ``self.result``.

        Does nothing (returns the current, normally empty, result) when
        the DOMAIN_MALWARE_SCAN setting is disabled.
        """
        if not settings_enabled('DOMAIN_MALWARE_SCAN'):
            logger.info('Domain Malware check disabled in settings')
            return self.result
        msg = 'Performing Malware check on extracted domains'
        append_scan_status(checksum, msg)
        self.domainlist = get_domains(urls)
        if self.domainlist:
            self.update()
            self.malware_check()
            self.maltrail_check()
            self.gelocation()
        return self.result


# Helper Functions

def verify_domain(checkeddom):
    """Return True if *checkeddom* looks like a usable domain name.

    A candidate must be a string that:
    - is longer than two characters,
    - contains a dot,
    - does not end with a dot,
    - has at least one ASCII letter or digit.

    Non-string input yields False instead of raising or logging.
    """
    if not isinstance(checkeddom, str):
        return False
    return (len(checkeddom) > 2
            and '.' in checkeddom
            and not checkeddom.endswith('.')
            and re.search(r'[a-zA-Z0-9]', checkeddom) is not None)


def get_netloc(url):
    """Return the network-location part of *url*, or None.

    A URL without a scheme is re-parsed with a ``//`` prefix so that
    urlparse puts the host in ``netloc`` rather than ``path``. The netloc
    (which, per urlparse, may include user info and a port) is returned
    only when verify_domain() accepts it. Parsing errors are logged and
    also yield None.
    """
    try:
        parsed = urlparse(url)
        if not parsed.scheme:
            parsed = urlparse('//' + url)
        netloc = str(parsed.netloc)
        return netloc if verify_domain(netloc) else None
    except Exception:
        logger.exception('[ERROR] Extracting Domain form URL')
        return None


def sanitize_domain(domain):
    """Sanitize domain to be RFC1034 compliant.

    Applies three steps in order:
    1. Keeps only the text before the first underscore.
    2. Removes every character that is not a word character, dot or hyphen.
    3. Strips leading and trailing hyphens.
    """
    domain = domain.split('_')[0]
    # Only a leading '^' negates a character class; any further '^'
    # inside the brackets would be a literal that survives the filter.
    domain = re.sub(r'[^\w.\-]', '', domain)
    # str.strip avoids deep recursion on long hyphen runs.
    return domain.strip('-')


def get_domains(urls):
    """Return the set of sanitized, valid host names found in *urls*.

    Scheme-less entries are parsed as if prefixed with ``//`` so that
    urlparse treats the leading text as the host. A URL that cannot be
    parsed or sanitized is logged and skipped individually, so one bad
    entry does not discard the domains from the others. Returns an empty
    set when *urls* is empty or None.
    """
    domains = set()
    for url in urls or ():
        try:
            parse_uri = urlparse(url)
            if not parse_uri.scheme:
                parse_uri = urlparse('//' + url)
            hostname = parse_uri.hostname
            if not hostname:
                # e.g. 'mailto:x' or 'http:///path' carry no host.
                continue
            domain = sanitize_domain(hostname)
        except Exception:
            # e.g. urlparse raises ValueError on 'http://[::1'.
            logger.exception('[ERROR] Extracting Domain form URL')
            continue
        if verify_domain(domain):
            domains.add(domain)
    return domains
